
import copy
import os

import numpy as np
import pickle

from matplotlib import pyplot as plt

try:
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # keep the name bound for any code that still uses it, without crashing
    # on newer Matplotlib installs.
    from matplotlib.cm import get_cmap
except ImportError:
    get_cmap = plt.get_cmap

plt.rc('axes', labelsize=16)  # fontsize of the x and y labels
plt.rc('legend', fontsize=12)  # legend fontsize

name = "tab10"
# name = "tab20"
cmap = plt.get_cmap(name)  # qualitative palette (a ListedColormap), exposes .colors
base = list(cmap.colors)  # palette as a plain list of RGB tuples
colors = [base[2], base[-1], base[3]]  # fixed 3-color subset; not referenced in this section


# Names of the saved TVD curves to plot, in plotting (and legend) order.
lst = ['joint', 'cond_cov0', 'cond_cov1', 'covid', 'pneum']
# lmap={'joint':'$P(Covid,Pneumonia)$', 'cond_cov0':'$P(Pneumonia|Covid=0)$', 'cond_cov1':'$P(Pneumonia|Covid=1)$',
#       'covid':'$P(Covid)$', 'pneum':'$P(Pneumonia)$'}

# LaTeX legend label for every entry of `lst` (short "Pneum" spelling throughout).
lmap = {
    'joint': '$P(Covid,Pneum)$',
    'cond_cov0': '$P(Pneum|Covid=0)$',
    'cond_cov1': '$P(Pneum|Covid=1)$',
    'covid': '$P(Covid)$',
    'pneum': '$P(Pneum)$',
}


# Experiment to plot; earlier runs kept for reference.
# exp_name="fixedDataset"
# exp_name="fixedDatasetRun2"  # working
# exp_name="fixedDatasetRun3"  # working
exp_name = "fixedDatasetRun4"  # working
plot_title = 'Total variation distance after training on COVIDx CXR-3'

# Root directory holding saved runs. Defaults to /SaveDir; set the
# TVD_SAVE_DIR environment variable to read results from elsewhere.
save_root = os.environ.get("TVD_SAVE_DIR", "/SaveDir")
path = f"{save_root}/{exp_name}/tvd"

delta = 20  # number of consecutive entries averaged into one plotted point
epochs = 1000  # at most this many leading entries of each curve are used

def _windowed_mean_and_mad(values, window, limit):
    """Summarize ``values`` over consecutive, non-overlapping windows.

    Only the first ``min(limit, len(values))`` entries are considered; a
    trailing partial window is dropped.

    Args:
        values: array indexed along its first axis (here, a saved TVD curve).
        window: number of entries per window; must be positive.
        limit: maximum number of leading entries to use.

    Returns:
        ``(means, mads)``: for each window, ``np.mean`` of its entries and the
        mean absolute deviation of its entries from that mean.

    Raises:
        ValueError: if ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")
    n_windows = min(limit, len(values)) // window
    chunks = [values[i * window:(i + 1) * window] for i in range(n_windows)]
    means = [np.mean(chunk) for chunk in chunks]
    mads = [np.mean(np.abs(chunk - mean)) for chunk, mean in zip(chunks, means)]
    return means, mads


new_tvd = {}    # dist -> per-window mean TVD
new_kl = {}     # dist -> [] (initialized but not filled in this section)
tvd_diff = {}   # dist -> raw array loaded from f'{path}/{dist}.npy'
tvd_error = {}  # dist -> per-window mean absolute deviation (shaded band)
for dist in lst:
    tvd_diff[dist] = np.load(f'{path}/{dist}.npy')
    new_tvd[dist], tvd_error[dist] = _windowed_mean_and_mad(tvd_diff[dist], delta, epochs)
    new_kl[dist] = []

    # x position of each point = index of the first entry in its window.
    xaxis = [i * delta for i in range(len(new_tvd[dist]))]
    plt.plot(xaxis, new_tvd[dist], label=lmap[dist])

    # Shade mean +/- mean absolute deviation within each window.
    y, e = np.array(new_tvd[dist]), np.array(tvd_error[dist])
    plt.fill_between(xaxis, y - e, y + e, alpha=0.2)

    print('-->', dist)
    for step, value in zip(xaxis, new_tvd[dist]):
        print(step, value)

    # Raw values of the last `delta` entries, including the final one.
    print('last', tvd_diff[dist][-delta:])

# Finalize the figure on the axes the curves were drawn into.
# plt.gca() returns that axes directly; plt.subplot(111) was an indirect,
# Matplotlib-version-dependent way of fetching it.
ax = plt.gca()
ax.set_xlabel("Epochs")
ax.set_ylabel("Total variation distance (TVD)")

# A single legend call: an Axes holds one legend, so an earlier plain
# plt.legend() would just be replaced by this one.
ncol = 2
ax.legend(loc='upper center', bbox_to_anchor=(0.6, 1.0),
          ncol=ncol, fancybox=True, shadow=False)
# plt.title(plot_title)
ax.grid(True)
plt.show()



# Print the saved results for P(pneum | do(covid=v)), v in {0, 1}, read from
# f'{path}/intv{v}_pneum.pkl'. After the loop, `loaded_dict` holds the v=1 result.
# NOTE: pickle.load can execute arbitrary code -- only load files this
# project wrote itself, never files from an untrusted source.
for covid_value in (0, 1):
    with open(f'{path}/intv{covid_value}_pneum.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    print(f'P(pneum|do(covid={covid_value}))', loaded_dict)

